# Standard
from enum import Enum

import torch
from vllm.config import ParallelConfig
from vllm.logger import logger

from vllm_ascend.distributed.kvpool.backend.backend import Backend


class MmcDirect(Enum):
    """Copy-direction codes passed to the memcache store's batch calls.

    The integer ``value`` of a member is what gets handed to the store
    (see ``MemcacheBackend.get`` / ``MemcacheBackend.put`` in this file).

    NOTE(review): the letters presumably stand for L = local buffer,
    G = global (pooled) memory and H = host memory -- confirm against the
    memcache / memfabric documentation.
    """
    # Used by MemcacheBackend.put (batch_put_from_layers).
    COPY_L2G = 0
    # Used by MemcacheBackend.get (batch_get_into_layers).
    COPY_G2L = 1
    # Not referenced in the visible part of this file.
    COPY_G2H = 2
    # Not referenced in the visible part of this file.
    COPY_H2G = 3


class MemcacheBackend(Backend):
    """KV-pool backend backed by memcache's ``DistributedObjectStore``.

    The store is created and initialised with the rank taken from the
    parallel configuration. ``get`` and ``put`` are best-effort: failures
    are logged and never propagated to the caller.
    """

    def __init__(self, parallel_config: ParallelConfig):
        """Create the distributed object store and initialise it.

        Args:
            parallel_config: vLLM parallel configuration. Its ``rank`` is
                stored as ``self.rank`` and passed to the store's ``init``.

        Raises:
            ImportError: If the ``memcache`` package cannot be imported.
            RuntimeError: If the store's ``init`` returns a non-zero code.
            Exception: Any other error raised while creating or
                initialising the store is logged and re-raised unchanged.
        """
        try:
            from memcache import DistributedObjectStore  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Please install memcache by following the instructions at "
                "https://gitee.com/ascend/memfabric_hybrid "  # noqa: E501
                "to run vLLM with MemcacheConnector.") from e
        try:
            self.rank = parallel_config.rank
            self.store = DistributedObjectStore()
            res = self.store.init(self.rank)
            if res != 0:
                # Raise explicitly instead of using ``assert``: asserts are
                # stripped under ``python -O``, which would let a failed
                # initialisation go unnoticed.
                raise RuntimeError(
                    f"Memcache store initialization failed, res:{res}")
        except ValueError as e:
            logger.error("Configuration loading failed: %s", e)
            raise
        except Exception as exc:
            logger.error(
                "An error occurred while loading the configuration: %s", exc)
            raise

    def set_device(self) -> None:
        """Select ``npu:<rank>`` as the current torch NPU device.

        NOTE(review): ``self.rank`` comes from ``parallel_config.rank`` and
        is used directly as the device index -- confirm this is the intended
        index when the rank differs from the local device number.
        """
        device = torch.device(f"npu:{self.rank}")
        torch.npu.set_device(device)

    def register_buffer(self, ptrs: list[int], sizes: list[int]) -> None:
        """Register each ``(ptr, size)`` pair with the store.

        Pairs are formed with ``zip``, so surplus entries in the longer
        list are ignored.

        Args:
            ptrs: Buffer addresses to register.
            sizes: Size of each buffer, positionally matched to ``ptrs``.

        Raises:
            RuntimeError: On the first registration that returns a
                non-zero code; later pairs are not attempted.
        """
        for ptr, size in zip(ptrs, sizes):
            ret_value = self.store.register_buffer(ptr, size)
            if ret_value != 0:
                raise RuntimeError("Memcache memory registration failed.")

    def exists(self, keys: list[str]) -> list[int]:
        """Return the store's ``batch_is_exist`` result for ``keys``.

        The result is passed through untouched.
        """
        return self.store.batch_is_exist(keys)

    def get(self, key: list[str], addr: list[list[int]],
            size: list[list[int]]) -> None:
        """Fetch a batch of keys from the store into the given buffers.

        Best-effort: a non-zero per-key result code or an exception from
        the store is logged and swallowed; nothing is returned or raised.

        Args:
            key: Keys to fetch.
            addr: Destination addresses, passed through to the store.
            size: Sizes matching ``addr``, passed through to the store.
        """
        try:
            res = self.store.batch_get_into_layers(key, addr, size,
                                                   MmcDirect.COPY_G2L.value)
            # The message already carries the whole batch, so log it once
            # rather than once per failing element.
            if any(value != 0 for value in res):
                logger.error(f"Failed to get key {key},res:{res}")
        except Exception as e:
            logger.error(f"Failed to get key {key}. {e}")

    def put(self, key: list[str], addr: list[list[int]],
            size: list[list[int]]) -> None:
        """Store a batch of keys from the given buffers into the store.

        Best-effort: a non-zero per-key result code or an exception from
        the store is logged and swallowed; nothing is returned or raised.

        Args:
            key: Keys to store.
            addr: Source addresses, passed through to the store.
            size: Sizes matching ``addr``, passed through to the store.
        """
        try:
            res = self.store.batch_put_from_layers(key, addr, size,
                                                   MmcDirect.COPY_L2G.value)
            # The message already carries the whole batch, so log it once
            # rather than once per failing element.
            if any(value != 0 for value in res):
                logger.error(f"Failed to put key {key},res:{res}")
        except Exception as e:
            logger.error(f"Failed to put key {key},error:{e}")
